import numpy as np
from scipy.optimize import brentq

def h(a, ε, h̄):
    """Disutility from labor effort, h(a) = h̄ * a**(1 + 1/ε) / (1 + 1/ε).

    Parameters
    ----------
    a : np.array, effort
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function

    Returns
    -------
    np.array
    """
    exponent = 1 + 1/ε
    return h̄*exponent**(-1) * a**exponent

def h_inv(h_val, ε, h̄):
    """Inverse of the disutility from labor effort.

    Recovers effort a from h_val = h̄ * a**(1 + 1/ε) / (1 + 1/ε).

    Parameters
    ----------
    h_val : np.array, value of disutility of effort
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function

    Returns
    -------
    np.array
    """
    scaled = 1/h̄ * (1 + 1/ε) * h_val
    return scaled ** (ε/(1 + ε))

def dh(a, ε, h̄):
    """Marginal disutility from labor effort, h'(a) = h̄ * a**(1/ε).

    Parameters
    ----------
    a : np.array, effort
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function

    Returns
    -------
    np.array
    """
    power = 1/ε
    return h̄ * a**power

def ddh(a, ε, h̄):
    """Second derivative of the disutility from labor effort.

    h''(a) = (1/ε) * h̄ * a**(1/ε - 1).

    Parameters
    ----------
    a : np.array, effort
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function

    Returns
    -------
    np.array
    """
    inv_ε = 1/ε
    return inv_ε * h̄ * a**(inv_ε - 1)

def WorkerUtility_a(a, ε, h̄, p, τ):
    """Expected utility of the worker as a function of effort.

    With bonus exponent β = h'(a)/(1 - p), the base wage is
    a/(a*exp(β) + 1 - a) and the bonus wage is exp(β) times the base wage;
    they are weighted by a and 1 - a respectively in the expected log wage.
    ``FindEffort`` maximises this over a grid for performance-pay jobs.

    Parameters
    ----------
    a : np.array, effort
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function
    p : float, progressivity rate
    τ : float, the other tax parameter (only shifts utility by a constant)

    Returns
    -------
    np.array
    """
    bonus_factor = np.exp(dh(a, ε, h̄)/(1 - p))
    base_wage = a/(a*bonus_factor + 1 - a)
    bonus_wage = bonus_factor * base_wage
    expected_log_wage = a*np.log(bonus_wage) + (1 - a)*np.log(base_wage)
    level = np.log((1 - τ)/(1 - p))
    return level + (1 - p)*expected_log_wage - h(a, ε, h̄)

def FindEffort(ε, h̄, p, job_type):
    """Find equilibrium labor effort given the job type.

    Fixed-pay ("fp") effort has the closed form ((1 - p)/h̄)**(ε/(1 + ε)).

    Performance-pay ("pp") effort is found in two stages:
     - a grid search over effort in [1e-3, 0.5] picks the grid point that
       maximises ``WorkerUtility_a`` (evaluated with τ = 0, which only shifts
       utility by a constant and so does not move the argmax);
     - brentq then solves the first-order condition on a bracket around that
       grid point.
    Two stages are needed because the first-order condition can in general have
    more than one root; the grid search focuses the root-finder on the global
    maximum.

    Parameters
    ----------
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function
    p : float, rate of progressivity
    job_type : either "pp" (performance-pay) or "fp" (fix-pay)

    Returns
    -------
    float
        Equilibrium effort (the integer 1 at the upper corner for "pp"),
        or False if ``job_type`` is not recognised.
    """
    if job_type == "fp":
        return ((1 - p)/h̄)**(ε/(1 + ε))
    if job_type != "pp":
        return False

    def foc_gap(a):
        # Marginal benefit of effort (normalised to 1) minus its marginal cost.
        bonus = np.exp(dh(a, ε, h̄)/(1 - p)) - 1
        curvature = 1 + a * (1 - a) * ddh(a, ε, h̄)/(1 - p)
        return 1 - a/(a + 1/bonus) * curvature

    # !!! The grid is capped at 0.5 (hard upper bound of effort) to avoid
    # numerical errors.
    grid = np.linspace(1e-3, 0.5, 100)
    best = np.argmax(WorkerUtility_a(grid, ε, h̄, p, 0))

    if best <= 1:
        # Maximum at the lower edge of the grid: bracket down to almost zero.
        return brentq(foc_gap, 1e-6, grid[2])
    if best >= len(grid) - 2:
        # Maximum at the upper edge of the grid: allow effort up to 1.
        return 1 if foc_gap(1) > 0 else brentq(foc_gap, grid[-4], 1)
    return brentq(foc_gap, grid[best - 2], grid[best + 2])


def FindElasticity(ε, h̄, p):
    """Find elasticity of effort at the pp job w.r.t. p

    Solves for performance-pay effort via ``FindEffort(ε, h̄, p, "pp")`` and
    evaluates a closed-form ratio num/denom built from β = h'(a)/(1 - p),
    h''(a) (``ddh``) and a locally computed h'''(a), all at that effort.

    Parameters
    ----------
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function
    p : float, rate of progressivity
    Returns
    -------
    float
        Scalar elasticity (effort from FindEffort is a scalar here).

    Raises
    ------
    AssertionError
        If equilibrium effort is not strictly inside (0, 1) (corner solution).
        NOTE(review): this is an ``assert``, so the check disappears under
        ``python -O``.
    """
    a = FindEffort(ε,h̄, p, "pp")
    # The closed-form expression below is only valid at an interior optimum.
    assert 1 > a > 0, "Error: Corner solution for effort"
    
    β = dh(a, ε, h̄)/(1 - p)
    # Third derivative of h, i.e. d/da of ddh = (1/ε) h̄ a**(1/ε - 1).
    dddh = h̄*(1/ε)*(1/ε - 1)*a**(1/ε - 2)
    num = (np.exp(β) - 1)/np.exp(β) * 1/a + β/a
    # Denominator combines h'' and h''' terms at effort a; both parts are
    # scaled by 1/(1 - p).
    denom = (((2 - 3*a) * a * ddh(a, ε, h̄) + a**2 * (1 - a)*dddh)*(np.exp(β) - 1)**2/np.exp(β) * 1/(1 - p) 
                + ddh(a, ε, h̄)/(1-p))
    return num/denom

def Find_Var_log_y(ε, h̄, p, λ_θ, π, var_ratio, mean_ratio, Y_emp):
    """Find the implied variance of log earnings.

    Backs out sig2_θ from ``var_ratio`` and the location parameters
    μ_fp, μ_pp from ``Y_emp`` and ``mean_ratio``. It then returns the
    resulting cross-sectional variance of log earnings.

    Parameters
    ----------
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function
    p : float, rate of progressivity
    λ_θ : float, parameter entering via 1/λ_θ**2 and λ_θ/(λ_θ - 1)
    π : float, weight on performance-pay jobs
    var_ratio : float, ratio used to back out sig2_θ (must differ from 1)
    mean_ratio : float, sets μ_pp = μ_fp + log(mean_ratio)
    Y_emp : float, target level used to pin down μ_fp
    Returns
    -------
    float
        The implied variance, or False if the implied sig2_θ is negative.
    """
    effort_pp = FindEffort(ε, h̄, p, "pp")
    effort_fp = FindEffort(ε, h̄, p, "fp")
    bonus = dh(effort_pp, ε, h̄)/(1 - p)
    pareto_var = 1/λ_θ**2

    sig2_θ = 1/(var_ratio - 1)*bonus*effort_pp*(1 - effort_pp) - pareto_var
    if sig2_θ < 0:
        # No admissible calibration: the implied variance would be negative.
        return False

    mean_scale = λ_θ/(λ_θ - 1)*(π*mean_ratio*effort_pp + (1 - π)*effort_fp)
    μ_fp = np.log(Y_emp/mean_scale) - sig2_θ/2
    μ_pp = μ_fp + np.log(mean_ratio)

    base_pay = effort_pp/(np.exp(bonus)*effort_pp + 1 - effort_pp)
    gap = μ_pp + np.log(base_pay) + effort_pp*bonus - μ_fp - np.log(effort_fp)
    return (pareto_var + sig2_θ + π*effort_pp*(1 - effort_pp)*bonus**2
            + π*(1-π)*gap**2)

def Find_τ(p, G, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp):
    """Find the second parameter of the tax function.

    Chooses τ so that aggregate output net of G equals aggregate
    after-tax income. Both job types share the same (ε, h̄) and the same
    variance term sig2_θ.

    Parameters
    ----------
    p : float, rate of progressivity
    G : float, gov spending
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function
    λ_θ : float, parameter entering via λ_θ/(λ_θ - 1) and λ_θ/(λ_θ - (1 - p))
    π : float, weight on performance-pay jobs
    sig2_θ : float, variance term shared by both job types
    μ_pp, μ_fp : float, location terms for pp and fp jobs
    Returns
    -------
    float
    """
    keep = 1 - p
    effort_pp = FindEffort(ε, h̄, p, "pp")
    effort_fp = FindEffort(ε, h̄, p, "fp")
    bonus = dh(effort_pp, ε, h̄)/keep

    output = λ_θ/(λ_θ - 1)*np.exp(sig2_θ/2)*(π*effort_pp*np.exp(μ_pp) + (1 - π)*effort_fp*np.exp(μ_fp))

    # E[w^(1-p)] for each type of job
    base_pay = effort_pp/(effort_pp*np.exp(bonus) + 1 - effort_pp)
    pp_factor = base_pay**keep * (effort_pp*np.exp(keep*bonus) + 1 - effort_pp)
    pareto_moment = λ_θ/(λ_θ - keep)
    shift = keep**2*sig2_θ/2
    Ew_pp = pareto_moment * np.exp(keep*μ_pp + shift) * pp_factor
    Ew_fp = pareto_moment * np.exp(keep*μ_fp + shift) * effort_fp**keep

    return 1 - keep*(output - G)/(π*Ew_pp + (1 - π)*Ew_fp)

def Find_τ2(p, G, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp):
    """Find the second parameter of the tax function.

    Same budget condition as ``Find_τ``, but with job-specific parameters.
    Performance-pay jobs use (ε_pp, h̄_pp, sig2_pp); fixed-pay jobs use
    (ε, h̄, sig2_fp).

    Parameters
    ----------
    p : float, rate of progressivity
    G : float, gov spending
    ε : float, Frisch elasticity (fp job)
    h̄ : float, weight in the utility function (fp job)
    ε_pp, h̄_pp : float, the same two parameters for the pp job
    λ_θ : float, parameter entering via λ_θ/(λ_θ - 1) and λ_θ/(λ_θ - (1 - p))
    s_pp : share of pp jobs
    sig2_pp, μ_pp : float, variance and location terms for pp jobs
    sig2_fp, μ_fp : float, variance and location terms for fp jobs
    Returns
    -------
    float
    """
    keep = 1 - p
    effort_fp = FindEffort(ε, h̄, p, "fp")
    effort_pp = FindEffort(ε_pp, h̄_pp, p, "pp")
    bonus = dh(effort_pp, ε_pp, h̄_pp)/keep

    output = λ_θ/(λ_θ - 1)*(s_pp*effort_pp*np.exp(μ_pp + sig2_pp/2) + (1 - s_pp)*effort_fp*np.exp(μ_fp + sig2_fp/2))

    # E[w^(1-p)] for each type of job
    base_pay = effort_pp/(effort_pp*np.exp(bonus) + 1 - effort_pp)
    pp_factor = base_pay**keep * (effort_pp*np.exp(keep*bonus) + 1 - effort_pp)
    pareto_moment = λ_θ/(λ_θ - keep)
    Ew_pp = pareto_moment * np.exp(keep*μ_pp + keep**2*sig2_pp/2) * pp_factor
    Ew_fp = pareto_moment * np.exp(keep*μ_fp + keep**2*sig2_fp/2) * effort_fp**keep

    return 1 - keep*(output - G)/(s_pp*Ew_pp + (1 - s_pp)*Ew_fp)

def TaxIncidence(p, τ, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp, p_hat):
    """Incidence of increasing progressivity by p_hat.

    Decomposes the effect of raising the progressivity rate by ``p_hat`` into
    components for fixed-pay (fp) and performance-pay (pp) jobs. Both job types
    share (ε, h̄) and the variance term ``sig2_θ``.

    Parameters
    ----------
    p : float, rate of progressivity
    τ : float, the other tax parameter
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function
    λ_θ : float, parameter entering via λ_θ/(λ_θ - 1) and λ_θ/(λ_θ - (1 - p))
    π : float, weight on pp jobs in aggregate consumption C
    sig2_θ : float, variance term shared by both job types
    μ_pp, μ_fp : float, location terms for pp and fp jobs
    p_hat : float, size of the change in p; every component is linear in it

    Returns
    -------
    tuple of 12 values
        (ME_fp, WE_fp, EB_fp, ME_pp, WE_pp, EB_pp,
         ME_WE_redist_pp, ME_WE_insure_pp, ME_WE_co_pp,
         EB_SE_pp, EB_CO_pp, EB_CI_pp)
        with EB_pp = EB_SE_pp + EB_CO_pp + EB_CI_pp. The total for a job type is
        ME + WE - EB. NOTE(review): ME/WE/EB presumably mean mechanical,
        welfare and behavioural effects; this is inferred from the names, so confirm.
    """
    l_pp = FindEffort(ε, h̄, p, "pp")
    l_fp = FindEffort(ε, h̄, p, "fp")
    β = dh(l_pp, ε, h̄)/(1 - p)

    # Fixed-pay job: output and consumption levels.
    Y_fp = λ_θ/(λ_θ - 1) * np.exp(μ_fp + sig2_θ/2) * l_fp
    C_fp = (1 - τ)/(1 - p)*λ_θ/(λ_θ - (1 - p))*np.exp((1 - p)*μ_fp + (1 - p)**2*sig2_θ/2) * l_fp**(1-p)

    # Performance-pay job: output and consumption levels.
    Y_pp = λ_θ/(λ_θ - 1) * np.exp(μ_pp + sig2_θ/2) * l_pp
    w_base1 = l_pp/(l_pp*np.exp(β) + 1 - l_pp)
    term = w_base1**(1-p) * (l_pp*np.exp(β*(1-p)) + 1 - l_pp)
    C_pp = (1 - τ)/(1 - p)*λ_θ/(λ_θ - (1 - p))*np.exp((1 - p)*μ_pp + (1 - p)**2*sig2_θ/2) * term

    # Y = π*Y_pp + (1 - π)*Y_fp
    C = π*C_pp + (1 - π)*C_fp

    ME_fp = (μ_fp + (1 - p)*sig2_θ + 1/(λ_θ - (1 - p)) + np.log(l_fp) - 1/(1 - p))*C_fp*p_hat
    WE_fp = (1/(1 - p) - μ_fp - 1/λ_θ - np.log(l_fp))*C*p_hat
    eps_l_fp = ε/(1 + ε)
    EB_fp = (Y_fp/(1 - p) - C_fp)*eps_l_fp*p_hat
    # TE_fp = ME_fp + WE_fp - EB_fp

    term_ME = l_pp*β*np.exp((1 - p)*β)/(l_pp*np.exp((1-p)*β) + 1 - l_pp)
    ME_pp = (μ_pp + (1 - p)*sig2_θ + 1/(λ_θ - (1 - p)) + np.log(w_base1) + term_ME - 1/(1 - p))*C_pp*p_hat
    term_WE = - β*(l_pp - w_base1*np.exp(β))
    WE_pp = (1/(1 - p) - μ_pp - 1/λ_θ - np.log(w_base1) - l_pp*β - term_WE)*C*p_hat

    # Split of ME_pp + WE_pp into three parts (redist + insure + co).
    ME_WE_redist_pp = (μ_pp + (1 - p)*sig2_θ + 1/(λ_θ - (1 - p)) + np.log(w_base1) - 1/(1 - p))*C_pp*p_hat + (1/(1 - p) - μ_pp - 1/λ_θ - np.log(w_base1))*C*p_hat
    ME_WE_insure_pp = term_ME*C_pp*p_hat - l_pp*β*C*p_hat
    ME_WE_co_pp = - term_WE*C*p_hat

    # Effort elasticity at the pp job. It is computed once, because
    # FindElasticity re-solves for effort (a grid search plus brentq).
    eps_l_pp = FindElasticity(ε, h̄, p)
    term_SE = w_base1/l_pp + l_pp/(1-p) * (np.exp((1-p)*β) - 1)/(l_pp*np.exp((1-p)*β) + 1 - l_pp)
    EB_SE_pp = (Y_pp/(1 - p) - term_SE*C_pp)*eps_l_pp*p_hat
    term_CO_CI = l_pp*np.exp(β)/(l_pp*np.exp(β) + 1 - l_pp) - l_pp*np.exp((1 - p)*β)/(l_pp*np.exp((1-p)*β) + 1 - l_pp)
    EB_CO_pp = - β * term_CO_CI * C_pp * p_hat
    EB_CI_pp = eps_l_pp * 1/ε * β * term_CO_CI * C_pp * p_hat
    EB_pp = EB_SE_pp + EB_CO_pp + EB_CI_pp

    return (ME_fp, WE_fp, EB_fp, ME_pp, WE_pp, EB_pp,
            ME_WE_redist_pp, ME_WE_insure_pp, ME_WE_co_pp,
            EB_SE_pp, EB_CO_pp, EB_CI_pp)

def TaxIncidence2(p, τ, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp, p_hat):
    """Incidence of increasing progressivity by p_hat.

    Same decomposition as ``TaxIncidence``, but with job-specific parameters.
    Performance-pay jobs use (ε_pp, h̄_pp, sig2_pp); fixed-pay jobs use
    (ε, h̄, sig2_fp).

    Parameters
    ----------
    p : float, rate of progressivity
    τ : float, the other tax parameter
    ε, h̄ : float, Frisch elasticity and utility weight for the fp job
    ε_pp, h̄_pp : float, the same two parameters for the pp job
    λ_θ : float, parameter entering via λ_θ/(λ_θ - 1) and λ_θ/(λ_θ - (1 - p))
    s_pp : float, share of pp jobs (weight in aggregate consumption C)
    sig2_pp, μ_pp : float, variance and location terms for pp jobs
    sig2_fp, μ_fp : float, variance and location terms for fp jobs
    p_hat : float, size of the change in p; every component is linear in it

    Returns
    -------
    tuple of 12 values
        (ME_fp, WE_fp, EB_fp, ME_pp, WE_pp, EB_pp,
         ME_WE_redist_pp, ME_WE_insure_pp, ME_WE_co_pp,
         EB_SE_pp, EB_CO_pp, EB_CI_pp)
        with EB_pp = EB_SE_pp + EB_CO_pp + EB_CI_pp. The total for a job type is
        ME + WE - EB. NOTE(review): ME/WE/EB presumably mean mechanical,
        welfare and behavioural effects; this is inferred from the names, so confirm.
    """
    l_fp = FindEffort(ε, h̄, p, "fp")
    l_pp = FindEffort(ε_pp, h̄_pp, p, "pp")
    β = dh(l_pp, ε_pp, h̄_pp)/(1 - p)

    # Fixed-pay job: output and consumption levels.
    Y_fp = λ_θ/(λ_θ - 1) * np.exp(μ_fp + sig2_fp/2) * l_fp
    C_fp = (1 - τ)/(1 - p)*λ_θ/(λ_θ - (1 - p))*np.exp((1 - p)*μ_fp + (1 - p)**2*sig2_fp/2) * l_fp**(1-p)

    # Performance-pay job: output and consumption levels.
    Y_pp = λ_θ/(λ_θ - 1) * np.exp(μ_pp + sig2_pp/2) * l_pp
    w_base1 = l_pp/(l_pp*np.exp(β) + 1 - l_pp)
    term = w_base1**(1-p) * (l_pp*np.exp(β*(1-p)) + 1 - l_pp)
    C_pp = (1 - τ)/(1 - p)*λ_θ/(λ_θ - (1 - p))*np.exp((1 - p)*μ_pp + (1 - p)**2*sig2_pp/2) * term

    # Y = s_pp*Y_pp + (1 - s_pp)*Y_fp
    C = s_pp*C_pp + (1 - s_pp)*C_fp

    ME_fp = (μ_fp + (1 - p)*sig2_fp + 1/(λ_θ - (1 - p)) + np.log(l_fp) - 1/(1 - p))*C_fp*p_hat
    WE_fp = (1/(1 - p) - μ_fp - 1/λ_θ - np.log(l_fp))*C*p_hat
    eps_l_fp = ε/(1 + ε)
    EB_fp = (Y_fp/(1 - p) - C_fp)*eps_l_fp*p_hat
    # TE_fp = ME_fp + WE_fp - EB_fp

    term_ME = l_pp*β*np.exp((1 - p)*β)/(l_pp*np.exp((1-p)*β) + 1 - l_pp)
    ME_pp = (μ_pp + (1 - p)*sig2_pp + 1/(λ_θ - (1 - p)) + np.log(w_base1) + term_ME - 1/(1 - p))*C_pp*p_hat
    term_WE = - β*(l_pp - w_base1*np.exp(β))
    WE_pp = (1/(1 - p) - μ_pp - 1/λ_θ - np.log(w_base1) - l_pp*β - term_WE)*C*p_hat

    # Split of ME_pp + WE_pp into three parts (redist + insure + co).
    ME_WE_redist_pp = (μ_pp + (1 - p)*sig2_pp + 1/(λ_θ - (1 - p)) + np.log(w_base1) - 1/(1 - p))*C_pp*p_hat + (1/(1 - p) - μ_pp - 1/λ_θ - np.log(w_base1))*C*p_hat
    ME_WE_insure_pp = term_ME*C_pp*p_hat - l_pp*β*C*p_hat
    ME_WE_co_pp = - term_WE*C*p_hat

    # Effort elasticity at the pp job (FindElasticity re-solves for effort).
    eps_l_pp = FindElasticity(ε_pp, h̄_pp, p)
    term_SE = w_base1/l_pp + l_pp/(1-p) * (np.exp((1-p)*β) - 1)/(l_pp*np.exp((1-p)*β) + 1 - l_pp)
    EB_SE_pp = (Y_pp/(1 - p) - term_SE*C_pp)*eps_l_pp*p_hat
    term_CO_CI = l_pp*np.exp(β)/(l_pp*np.exp(β) + 1 - l_pp) - l_pp*np.exp((1 - p)*β)/(l_pp*np.exp((1-p)*β) + 1 - l_pp)
    EB_CO_pp = - β * term_CO_CI * C_pp * p_hat
    EB_CI_pp = eps_l_pp * 1/ε_pp * β * term_CO_CI * C_pp * p_hat
    EB_pp = EB_SE_pp + EB_CO_pp + EB_CI_pp

    return (ME_fp, WE_fp, EB_fp, ME_pp, WE_pp, EB_pp,
            ME_WE_redist_pp, ME_WE_insure_pp, ME_WE_co_pp,
            EB_SE_pp, EB_CO_pp, EB_CI_pp)

def Find_p_opt(G, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp, *, bracket=(-0.5, 0.9)):
    """Find optimal progressivity by root-finding on the aggregate tax formula.

    For each candidate p, τ is set by ``Find_τ`` (budget balance). The
    π-weighted sum of ME + WE - EB across job types (from ``TaxIncidence``
    with p_hat = 1) is then driven to zero with Brent's method.

    Parameters
    ----------
    G, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp :
        Model parameters, passed through to ``Find_τ`` and ``TaxIncidence``.
    bracket : tuple of two floats, keyword-only, default (-0.5, 0.9)
        Interval for p handed to ``brentq``. The condition must change sign on it.

    Returns
    -------
    float
        The optimal progressivity rate p.
    """
    def foc(p):
        τ = Find_τ(p, G, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp)
        ME_fp, WE_fp, EB_fp, ME_pp, WE_pp, EB_pp, *_ = TaxIncidence(
            p, τ, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp, 1)
        return π*(ME_pp + WE_pp - EB_pp) + (1-π)*(ME_fp + WE_fp - EB_fp)

    p_lo, p_hi = bracket
    return brentq(foc, p_lo, p_hi)

def Find_p_opt2(G, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp,
                *, bracket=(0.12, 0.6)):
    """Find optimal progressivity by root-finding on the aggregate tax formula.

    Job-specific version of ``Find_p_opt``. For each candidate p, τ is set by
    ``Find_τ2``. The s_pp-weighted sum of ME + WE - EB (from ``TaxIncidence2``
    with p_hat = 1) is then driven to zero with Brent's method.

    Parameters
    ----------
    G, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp :
        Model parameters, passed through to ``Find_τ2`` and ``TaxIncidence2``.
    bracket : tuple of two floats, keyword-only, default (0.12, 0.6)
        Interval for p handed to ``brentq``. The condition must change sign on it.
        An earlier version of this code used (-0.1, 0.6).

    Returns
    -------
    float
        The optimal progressivity rate p.
    """
    def foc(p):
        τ = Find_τ2(p, G, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp)
        ME_fp, WE_fp, EB_fp, ME_pp, WE_pp, EB_pp, *_ = TaxIncidence2(
            p, τ, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp, 1)
        return s_pp*(ME_pp + WE_pp - EB_pp) + (1-s_pp)*(ME_fp + WE_fp - EB_fp)

    p_lo, p_hi = bracket
    return brentq(foc, p_lo, p_hi)


def FindSCPE(G, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp, *, bracket=(-0.5, 0.9)):
    """Find progressivity in SCPE by root-finding on the tax formula.

    Same procedure as ``Find_p_opt``, but with a different pp-job term. The pp
    job contributes only ME_WE_redist_pp + ME_WE_insure_pp - EB_SE_pp; the
    ME_WE_co_pp, EB_CO_pp and EB_CI_pp parts returned by ``TaxIncidence`` are
    left out. The fp-job term is unchanged.

    Parameters
    ----------
    G, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp :
        Model parameters, passed through to ``Find_τ`` and ``TaxIncidence``.
    bracket : tuple of two floats, keyword-only, default (-0.5, 0.9)
        Interval for p handed to ``brentq``. The condition must change sign on it.

    Returns
    -------
    float
        The SCPE progressivity rate p.
    """
    def foc(p):
        τ = Find_τ(p, G, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp)
        (ME_fp, WE_fp, EB_fp, ME_pp, WE_pp, EB_pp, ME_WE_redist_pp, ME_WE_insure_pp, ME_WE_co_pp,
            EB_SE_pp, EB_CO_pp, EB_CI_pp) = TaxIncidence(p, τ, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp, 1)
        return π*(ME_WE_redist_pp + ME_WE_insure_pp - EB_SE_pp) + (1-π)*(ME_fp + WE_fp - EB_fp)

    p_lo, p_hi = bracket
    return brentq(foc, p_lo, p_hi)

def FindSCPE2(G, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp,
              *, bracket=(0.12, 0.6)):
    """Find progressivity in SCPE by root-finding on the tax formula.

    Job-specific version of ``FindSCPE``. The pp job contributes only
    ME_WE_redist_pp + ME_WE_insure_pp - EB_SE_pp (from ``TaxIncidence2``,
    with τ from ``Find_τ2``). The fp job contributes ME_fp + WE_fp - EB_fp,
    and the two are weighted by s_pp.

    Parameters
    ----------
    G, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp :
        Model parameters, passed through to ``Find_τ2`` and ``TaxIncidence2``.
    bracket : tuple of two floats, keyword-only, default (0.12, 0.6)
        Interval for p handed to ``brentq``. The condition must change sign on it.

    Returns
    -------
    float
        The SCPE progressivity rate p.
    """
    def foc(p):
        τ = Find_τ2(p, G, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp)
        (ME_fp, WE_fp, EB_fp, ME_pp, WE_pp, EB_pp, ME_WE_redist_pp, ME_WE_insure_pp, ME_WE_co_pp,
            EB_SE_pp, EB_CO_pp, EB_CI_pp) = TaxIncidence2(p, τ, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp, 1)
        return s_pp*(ME_WE_redist_pp + ME_WE_insure_pp - EB_SE_pp) + (1-s_pp)*(ME_fp + WE_fp - EB_fp)

    p_lo, p_hi = bracket
    return brentq(foc, p_lo, p_hi)

def FindSWF(p, G, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp):
    """Find the value of the social welfare function.

    τ is set by ``Find_τ`` for the given p. Welfare is the π-weighted average
    of the pp-job and fp-job welfare expressions.

    Parameters
    ----------
    p : float, rate of progressivity
    G : float, gov spending
    ε : float, Frisch elasticity
    h̄ : float, weight in the utility function
    λ_θ, π, sig2_θ, μ_pp, μ_fp : remaining model parameters (see ``Find_τ``)
    Returns
    -------
    float
    """
    τ = Find_τ(p, G, ε, h̄, λ_θ, π, sig2_θ, μ_pp, μ_fp)

    effort_pp = FindEffort(ε, h̄, p, "pp")
    effort_fp = FindEffort(ε, h̄, p, "fp")
    keep = 1 - p
    bonus = dh(effort_pp, ε, h̄)/keep
    base_pay = effort_pp/(effort_pp*np.exp(bonus) + 1 - effort_pp)

    # Term common to both job types.
    tax_level = np.log((1-τ)/keep)
    welfare_pp = (tax_level + keep*(1/λ_θ + μ_pp)
                  + keep*(np.log(base_pay) + effort_pp*bonus) - h(effort_pp, ε, h̄))
    welfare_fp = (tax_level + keep*(1/λ_θ + μ_fp)
                  + keep*np.log(effort_fp) - h(effort_fp, ε, h̄))
    return π*welfare_pp + (1-π)*welfare_fp

def FindSWF2(p, G, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp):
    """Find the value of the social welfare function.

    Job-specific version of ``FindSWF``. τ comes from ``Find_τ2``. The pp
    job uses (ε_pp, h̄_pp), the fp job uses (ε, h̄), and the two are weighted
    by s_pp.

    Parameters
    ----------
    p : float, rate of progressivity
    G : float, gov spending
    ε : float, Frisch elasticity (fp job)
    h̄ : float, weight in the utility function (fp job)
    ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp :
        remaining model parameters (see ``Find_τ2``)
    Returns
    -------
    float
    """
    τ = Find_τ2(p, G, ε, h̄, ε_pp, h̄_pp, λ_θ, s_pp, sig2_pp, μ_pp, sig2_fp, μ_fp)

    effort_fp = FindEffort(ε, h̄, p, "fp")
    effort_pp = FindEffort(ε_pp, h̄_pp, p, "pp")
    keep = 1 - p
    bonus = dh(effort_pp, ε_pp, h̄_pp)/keep
    base_pay = effort_pp/(effort_pp*np.exp(bonus) + 1 - effort_pp)

    # Term common to both job types.
    tax_level = np.log((1-τ)/keep)
    welfare_pp = (tax_level + keep*(1/λ_θ + μ_pp)
                  + keep*(np.log(base_pay) + effort_pp*bonus) - h(effort_pp, ε_pp, h̄_pp))
    welfare_fp = (tax_level + keep*(1/λ_θ + μ_fp)
                  + keep*np.log(effort_fp) - h(effort_fp, ε, h̄))
    return s_pp*welfare_pp + (1-s_pp)*welfare_fp

def Find_GeneralIncidence(linear_utility=False):
    """
    Returns a function which computes the tax incidence of a given tax reform for a single agent
    with an arbitrary utility starting from the arbitrary tax schedule, given the initial equilibrium.

    A system of seven linear equations (Eq1-Eq7) is solved symbolically with
    ``sympy.solve`` for (ŷ, b̂, Û, l̂, dv̂_y, dv̂_yb, X̂). The solution is turned
    into a numpy-callable with ``sympy.lambdify``. The two utility cases differ
    only in Eq4 and Eq5.

    Parameters
    ----------
    linear_utility : bool, default False
        Use the linear-utility version of Eq4/Eq5 (no σ_y, σ_yb, R_y, R_yb terms).

    Returns
    -------
    callable
        Takes the 20 inputs
        (R̂_y, R̂_yb, r̂_y, r̂_yb, R_y, R_yb, r_y, r_yb, dr_y, dr_yb,
         b, l, ddh, dddh, σ_y, σ_yb, X, dv_y, dv_yb, θ)
        in both cases and returns [ŷ, b̂, Û, l̂, dv̂_y, dv̂_yb, X̂] in that order.
        Under linear utility, R_y, R_yb, σ_y and σ_yb are accepted but do not
        enter the equations.
    """
    import sympy as sym
    # non-linear utility, non-linear tax

    b, l, dv_yb, dv_y, ddh, dddh, X, θ, σ_y, σ_yb = sym.symbols('b l dv_yb dv_y ddh dddh X θ σ_y σ_yb')
    R_y, R_yb, r_y, r_yb, dr_y, dr_yb = sym.symbols('R_y R_yb r_y r_yb dr_y dr_yb')
    R̂_y, R̂_yb, r̂_y, r̂_yb, ŷ, b̂, Û, l̂, dv̂_yb, dv̂_y, X̂  = sym.symbols('R̂_y R̂_yb r̂_y r̂_yb ŷ b̂ Û l̂ dv̂_yb dv̂_y X̂')

    # Equations common to both utility specifications.
    Eq1 = - ŷ - R̂_y/r_y + Û/dv_y - l*ddh/dv_y*l̂
    Eq2 = - (ŷ + b̂) - R̂_yb/r_yb + Û/dv_yb + (1 - l)*ddh/dv_yb*l̂
    Eq3 = - θ*l̂ + l̂*b + l*(ŷ + b̂) + (1 - l)*ŷ
    Eq6 = - X̂ + dv̂_y/dv_y**2 - dv̂_yb/dv_yb**2
    Eq7 = b̂ + l*(1-l)*ddh*X̂ + ((1 - 2*l)*ddh + l*(1-l)*dddh)*X*l̂

    if linear_utility: # adjusted such that zero consumption does not yield error
        Eq4 = - dv̂_y/dv_y + r̂_y/r_y + (dr_y/r_y)*ŷ
        Eq5 = - dv̂_yb/dv_yb + r̂_yb/r_yb + (dr_yb/r_yb)*(ŷ + b̂)
        # The Input list matches the lambdify signature below; it previously
        # omitted σ_y and σ_yb.
        doc = """
        Solves the general tax incidence under a linear utility.

        Input: (R̂_y, R̂_yb, r̂_y, r̂_yb, R_y, R_yb, r_y, r_yb, dr_y, dr_yb, b, 
        l, ddh, dddh, σ_y, σ_yb, X, dv_y, dv_yb, θ)
        (R_y, R_yb, σ_y and σ_yb are accepted but unused under linear utility.)

        Output: (ŷ, b̂, Û, l̂, dv̂_y, dv̂_yb, X̂)
        """
    else:
        Eq4 = - dv̂_y/dv_y + r̂_y/r_y - σ_y*R̂_y/R_y + (dr_y/r_y - σ_y*r_y/R_y)*ŷ
        Eq5 = - dv̂_yb/dv_yb + r̂_yb/r_yb - σ_yb*R̂_yb/R_yb + (dr_yb/r_yb - σ_yb*r_yb/R_yb)*(ŷ + b̂)
        doc = """
        Solves the general tax incidence under an arbitrary utility. Careful, R_y and R_yb cannot be zero.

        Input: (R̂_y, R̂_yb, r̂_y, r̂_yb, R_y, R_yb, r_y, r_yb, dr_y, dr_yb, b, 
        l, ddh, dddh, σ_y, σ_yb, X, dv_y, dv_yb, θ)

        Output: (ŷ, b̂, Û, l̂, dv̂_y, dv̂_yb, X̂)
        """

    unknowns = (ŷ, b̂, Û, l̂, dv̂_y, dv̂_yb, X̂)
    sol_sym = sym.solve([Eq1, Eq2, Eq3, Eq4, Eq5, Eq6, Eq7], unknowns, simplify=False)
    # Index the solution dict explicitly so the output order is guaranteed.
    ordered_sol_list = [sol_sym[k] for k in unknowns]
    sol_numpy = sym.lambdify([R̂_y, R̂_yb, r̂_y, r̂_yb, R_y, R_yb, r_y, r_yb, dr_y, dr_yb,
                b, l, ddh, dddh, σ_y, σ_yb, X, dv_y, dv_yb, θ], ordered_sol_list, 'numpy')
    sol_numpy.__doc__ = doc
    return sol_numpy

def Find_GeneralIncidence_fp(linear_utility=False):
    """
    Returns a function which computes the tax incidence of a given tax reform for a single agent
    with an arbitrary utility starting from the arbitrary tax schedule, given the initial equilibrium.

    Fixed-pay version: four linear equations are solved symbolically with
    ``sympy.solve`` for (ŷ, l̂, Û, dv̂). The solution is compiled with
    ``sympy.lambdify``. The two utility cases differ only in Eq3 and in
    whether σ is an input.

    Parameters
    ----------
    linear_utility : bool, default False
        Use the linear-utility Eq3 (no σ or R terms).

    Returns
    -------
    callable
        Inputs (R̂, r̂, R, r, dr, l, dh, ddh, dv, θ) under linear utility, or
        (R̂, r̂, R, r, dr, l, dh, ddh, σ, dv, θ) otherwise. It returns
        [ŷ, l̂, Û, dv̂] in exactly that order.
    """
    import sympy as sym
    # non-linear utility, non-linear tax

    l, dv, dh, ddh, θ, σ = sym.symbols('l dv dh ddh θ σ')
    R, r, dr = sym.symbols('R r dr')
    R̂, r̂, ŷ, Û, l̂, dv̂  = sym.symbols('R̂ r̂ ŷ Û l̂ dv̂')

    # Equations common to both utility specifications.
    Eq1 = - ŷ - R̂/r + Û/dv + dh/dv*l̂
    Eq2 = - θ*l̂ + ŷ
    Eq4 = - l̂ + dv̂/dv * dh/ddh

    if linear_utility: # adjusted such that zero consumption does not yield error
        Eq3 = - dv̂/dv + r̂/r + (dr/r)*ŷ
        args = [R̂, r̂, R, r, dr, l, dh, ddh, dv, θ]
        doc = """
        Solves the general tax incidence under a linear utility.

        Input: (R̂, r̂, R, r, dr, l, dh, ddh, dv, θ)

        Output: (ŷ, l̂, Û, dv̂)
        """
    else:
        Eq3 = - dv̂/dv + r̂/r - σ*R̂/R + (dr/r - σ*r/R)*ŷ
        args = [R̂, r̂, R, r, dr, l, dh, ddh, σ, dv, θ]
        doc = """
        Solves the general tax incidence under an arbitrary utility. Careful, R cannot be zero.

        Input: (R̂, r̂, R, r, dr, l, dh, ddh, σ, dv, θ)

        Output: (ŷ, l̂, Û, dv̂)
        """

    unknowns = (ŷ, l̂, Û, dv̂)
    sol_sym = sym.solve([Eq1, Eq2, Eq3, Eq4], unknowns, simplify=False)
    # Index the solution dict explicitly instead of relying on dict order, so
    # the returned values really follow the documented Output order.
    ordered_sol_list = [sol_sym[k] for k in unknowns]
    sol_numpy = sym.lambdify(args, ordered_sol_list, 'numpy')
    sol_numpy.__doc__ = doc
    return sol_numpy

def Find_GeneralIncidence2(linear_utility=False):
    """
    Returns a function which computes the tax incidence of a given tax reform for a single agent
    with an arbitrary utility starting from the arbitrary tax schedule, given the initial equilibrium.

    Allows for separate taxation of bonuses.

    A system of nine linear equations (Eq1-Eq9 below) is solved symbolically
    with ``sympy.solve`` for the unknowns
    (ŷ, b̂, Û, l̂, v̂_y1, v̂_yb1, v̂_yb2, term2, term3). The solution is turned
    into a numpy-callable with ``sympy.lambdify``.

    Parameters
    ----------
    linear_utility : bool, default False
        Selects the equation system. The branches share Eq1 and Eq7-Eq9 and
        differ in Eq2-Eq6. The non-linear branch also takes σ_y and σ_yb
        as inputs.

    Returns
    -------
    callable
        Returns the nine unknowns as a list, in the order given in its
        ``__doc__`` ("Output: ..."). The argument order is also listed there.

    Notes
    -----
    ``term1`` is an *input* of the returned function. Its defining equation
    (the commented-out Eq7) is not part of the solved system.
    Some inputs appear in the argument list but in no equation:
    - dddh is unused in both branches.
    - R_y and R_yb are unused in the linear branch.
    - b and θ are unused in the non-linear branch.
    NOTE(review): the last point differs from the linear Eq3; confirm this is
    intended (the non-linear branch was once marked "not ready yet").
    """
    import sympy as sym

    # Non-hatted symbols: levels/derivatives at the initial equilibrium (inputs).
    b, l, v_y1, v_yb1, v_yb2, ddh, dddh, θ, σ_y, σ_yb = sym.symbols('b l v_y1 v_yb1 v_yb2 ddh dddh θ σ_y σ_yb')
    R_y, R_yb, R_y1, R_y11, R_yb1, R_yb11, R_yb12, R_yb2, R_yb22, R_yb21 = sym.symbols('R_y R_yb R_y1 R_y11 R_yb1 R_yb11 R_yb12 R_yb2 R_yb22 R_yb21')
    term1, term2, term3 = sym.symbols('term1 term2 term3')
    # Hatted R̂_* symbols are inputs; the other hatted symbols are unknowns.
    R̂_y, R̂_yb, R̂_y1, R̂_yb1, R̂_yb2, ŷ, b̂, Û, l̂, v̂_y1, v̂_yb1, v̂_yb2 = sym.symbols('R̂_y R̂_yb R̂_y1 R̂_yb1 R̂_yb2 ŷ b̂ Û l̂ v̂_y1 v̂_yb1 v̂_yb2')
    
    if linear_utility: # adjusted such that zero consumption does not yield error
        Eq1 = - ŷ - R̂_y/R_y1 + Û/v_y1 - l*ddh/v_y1*l̂
        Eq2 = - v_y1*ŷ - v_yb2*b̂ - R̂_yb + Û + (1 - l)*ddh*l̂
        Eq3 = - θ*l̂ + l̂*b + l*(ŷ + b̂) + (1 - l)*ŷ
        # Eq4-Eq6: responses of v_y1, v_yb1, v_yb2 (no σ terms in this branch).
        Eq4 = - v̂_y1/v_y1 + R̂_y1/R_y1 + R_y11/R_y1*ŷ
        Eq5 = - v̂_yb1/v_yb1 + R̂_yb1/R_yb1 + R_yb11/R_yb1*ŷ + R_yb12/R_yb1*b̂
        Eq6 = - v̂_yb2/v_yb2 + R̂_yb2/R_yb2 + R_yb22/R_yb2*b̂ + R_yb21/R_yb2*ŷ
        # Eq7 = - term1 + (ddh + l*dddh)*((1 - l + l*v_yb1/v_y1)/v_yb2 - 1/v_y1) + (v_yb1/v_y1 - 1)*l*ddh/v_yb2
        # Eq7/Eq8 define the auxiliary unknowns term2 and term3 used in Eq9.
        Eq7 = - term2 + v̂_yb1/v_yb1 - v̂_y1/v_y1 - v̂_yb2/v_yb2
        Eq8 = - term3 - l*(1-l)*ddh/v_yb2**2*v̂_yb2 + l**2*v_yb1/v_y1*ddh/v_yb2*term2 + l*ddh*v̂_y1/v_y1**2
        Eq9 = b̂ + term1*l̂ + term3
        
        sol_sym = sym.solve([Eq1, Eq2, Eq3, Eq4, Eq5, Eq6, Eq7, Eq8, Eq9], 
                (ŷ, b̂, Û, l̂, v̂_y1, v̂_yb1, v̂_yb2, term2, term3), simplify=False)

        # sol_numpy = sym.lambdify([R̂_y, R̂_yb, R̂_y1, R̂_yb1, R̂_yb2, 
        #                           R_y, R_yb, R_y1, R_y11, R_yb1, 
        #                           R_yb11, R_yb12, R_yb2, R_yb22, R_yb21,
        #                           b, l, ddh, dddh, v_y1, v_yb1, v_yb2, θ, term1], 
        #                           list(sol_sym.values()), 'numpy')

        # Index the solution dict explicitly so the output order matches the doc.
        ordered_sol_list = [sol_sym[k] for k in (ŷ, b̂, Û, l̂, v̂_y1, v̂_yb1, v̂_yb2, term2, term3)]
        sol_numpy = sym.lambdify([R̂_y, R̂_yb, R̂_y1, R̂_yb1, R̂_yb2, 
                                  R_y, R_yb, R_y1, R_y11, R_yb1, 
                                  R_yb11, R_yb12, R_yb2, R_yb22, R_yb21,
                                  b, l, ddh, dddh, v_y1, v_yb1, v_yb2, θ, term1], 
                                  ordered_sol_list, 'numpy')


        sol_numpy.__doc__ = """
        Solves the general tax incidence under a linear utility.

        Input: (R̂_y, R̂_yb, R̂_y1, R̂_yb1, R̂_yb2, 
                R_y, R_yb, R_y1, R_y11, R_yb1, 
                R_yb11, R_yb12, R_yb2, R_yb22, R_yb21,
                b, l, ddh, dddh, v_y1, v_yb1, v_yb2, θ, term1)

        Output: (ŷ, b̂, Û, l̂, v̂_y1, v̂_yb1, v̂_yb2, term2, term3)
        """

    else:
        # assert 0 == 1, "Error: this case not ready yet"
        Eq1 = - ŷ - R̂_y/R_y1 + Û/v_y1 - l*ddh/v_y1*l̂
        Eq2 = - b̂ - R_yb1/R_yb2*ŷ - R̂_yb/R_yb2 + Û/v_yb2 + (1 - l)*ddh/v_yb2*l̂
        # Eq3 gives Û directly in terms of the reform inputs R̂_y, R̂_yb.
        Eq3 = - Û + (l/v_yb2 + (1 - l*R_yb1/R_yb2)/v_y1)**(-1) * (l*R̂_yb/R_yb2 + (1 - l*R_yb1/R_yb2)*R̂_y/R_y1)
        # Eq4-Eq6: as in the linear branch plus σ-weighted terms (requires R_y, R_yb != 0).
        Eq4 = - v̂_y1/v_y1 + R̂_y1/R_y1 + R_y11/R_y1*ŷ - σ_y*(R̂_y/R_y + R_y1/R_y*ŷ)
        Eq5 = - v̂_yb1/v_yb1 + R̂_yb1/R_yb1 + R_yb11/R_yb1*ŷ + R_yb12/R_yb1*b̂ - σ_yb*(R̂_yb/R_yb + R_yb1/R_yb*ŷ + R_yb2/R_yb*b̂)        
        Eq6 = - v̂_yb2/v_yb2 + R̂_yb2/R_yb2 + R_yb22/R_yb2*b̂ + R_yb21/R_yb2*ŷ - σ_yb*(R̂_yb/R_yb + R_yb1/R_yb*ŷ + R_yb2/R_yb*b̂)
        # Eq7 = - term1 + (ddh + l*dddh)*((1 - l + l*v_yb1/v_y1)/v_yb2 - 1/v_y1) + (v_yb1/v_y1 - 1)*l*ddh/v_yb2
        # Eq7/Eq8 define the auxiliary unknowns term2 and term3 used in Eq9.
        Eq7 = - term2 + v̂_yb1/v_yb1 - v̂_y1/v_y1 - v̂_yb2/v_yb2
        Eq8 = - term3 - l*(1-l)*ddh/v_yb2**2*v̂_yb2 + l**2*v_yb1/v_y1*ddh/v_yb2*term2 + l*ddh*v̂_y1/v_y1**2
        Eq9 = b̂ + term1*l̂ + term3

        sol_sym = sym.solve([Eq1, Eq2, Eq3, Eq4, Eq5, Eq6, Eq7, Eq8, Eq9], 
                            (ŷ, b̂, Û, l̂, v̂_y1, v̂_yb1, v̂_yb2, term2, term3), simplify=False)
        
        # sol_numpy = sym.lambdify([R̂_y, R̂_yb, R̂_y1, R̂_yb1, R̂_yb2, 
                                #   R_y, R_yb, R_y1, R_y11, R_yb1, 
                                #   R_yb11, R_yb12, R_yb2, R_yb22, R_yb21,
                                #   b, l, ddh, dddh, σ_y, σ_yb, v_y1, v_yb1, v_yb2, θ, term1], 
                                #   list(sol_sym.values()), 'numpy')

        # Index the solution dict explicitly so the output order matches the doc.
        ordered_sol_list = [sol_sym[k] for k in (ŷ, b̂, Û, l̂, v̂_y1, v̂_yb1, v̂_yb2, term2, term3)]

        sol_numpy = sym.lambdify([R̂_y, R̂_yb, R̂_y1, R̂_yb1, R̂_yb2, 
                                  R_y, R_yb, R_y1, R_y11, R_yb1, 
                                  R_yb11, R_yb12, R_yb2, R_yb22, R_yb21,
                                  b, l, ddh, dddh, σ_y, σ_yb, v_y1, v_yb1, v_yb2, θ, term1], 
                                  ordered_sol_list, 'numpy')
        
        sol_numpy.__doc__ = """
        Solves the general tax incidence under an arbitrary utility. Careful, R_y and R_yb cannot be zero.

        Input: (R̂_y, R̂_yb, R̂_y1, R̂_yb1, R̂_yb2, 
                R_y, R_yb, R_y1, R_y11, R_yb1, 
                R_yb11, R_yb12, R_yb2, R_yb22, R_yb21,
                b, l, ddh, dddh, σ_y, σ_yb, v_y1, v_yb1, v_yb2, θ, term1)

        Output: (ŷ, b̂, Û, l̂, v̂_y1, v̂_yb1, v̂_yb2, term2, term3)
        """
    return sol_numpy

def Find_Var_log_z(M, π, ε, s_pp, λ_θ, p_emp, σ2_fp):
    """Calculate cross-sectional variance of log-earnings given the value of σ2_fp.

    ``CalibrateModel`` uses this as the residual function for σ2_fp.

    Parameters
    ----------
    M : sequence of 4 floats, data moments (m1, m2, m3, m4). Only m1-m3 are
        used here; m4 is the target that ``CalibrateModel`` matches.
    π : float, effort at the pp job (enters as l_pp)
    ε : float, Frisch elasticity of the fp job (h̄ = 1)
    s_pp : float, share of pp jobs
    λ_θ : float, parameter entering via 1/λ_θ**2
    p_emp : float, empirical progressivity rate
    σ2_fp : float, candidate variance for fp jobs

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``M`` does not have exactly four entries.
    """
    # Unpack all four entries to keep validating the length of M.
    m1, m2, m3, _m4 = M

    σ2_β = m2*(m1 - 1)*(σ2_fp + 1/λ_θ**2)
    β = (σ2_β/(π*(1 - π)))**0.5
    # fp effort from the closed form with h̄ = 1.
    l_fp = (1 - p_emp)**(ε/(1 + ε))
    μ_dif = np.log(m3 * l_fp/π) - (1 - m2)/m2 * σ2_β/2
    base_pay1 = π/(1 - π + π*np.exp(β))
    # ε_pp, ρ and h̄_pp used to be computed here without being used. When
    # σ2_β = 0, ε_pp evaluated 0/0 and raised a spurious RuntimeWarning.
    Var_log_z = ( (m1*s_pp + 1 - s_pp)*(σ2_fp + 1/λ_θ**2) +
           s_pp*(1-s_pp)*(μ_dif + np.log(base_pay1) + π*β - np.log(l_fp))**2 )

    return Var_log_z

def CalibrateModel(M, π, ε, s_pp, λ_θ, p_emp, Y_emp, G_emp, *, σ2_fp_bracket=(0.1, 0.6)):
    """Calibrate the model given data moments and externally chosen parameters.

    First solves for σ2_fp such that ``Find_Var_log_z`` matches the fourth
    moment m4. The remaining parameters are then backed out in closed form.

    Parameters
    ----------
    M : sequence of 4 floats, data moments (m1, m2, m3, m4)
    π : float, effort at the pp job
    ε : float, Frisch elasticity of the fp job
    s_pp : float, share of pp jobs
    λ_θ : float, parameter entering via 1/λ_θ**2 and λ_θ/(λ_θ - 1)
    p_emp : float, empirical progressivity rate
    Y_emp : float, target level used to pin down μ_fp
    G_emp : float, government spending (passed through)
    σ2_fp_bracket : tuple of two floats, keyword-only, default (0.1, 0.6)
        Interval handed to ``brentq`` when solving for σ2_fp.

    Returns
    -------
    tuple
        (G_emp, ε, 1, ε_pp, h̄_pp, λ_θ, s_pp, σ2_pp, μ_pp, σ2_fp, μ_fp).
        This is the order of the parameters after ``p`` in Find_p_opt2,
        FindSCPE2 and FindSWF2. h̄ = 1 for the fp job, consistent with
        l_fp = (1 - p_emp)**(ε/(1 + ε)).
    """
    m1, m2, m3, m4 = M

    # brentq comes from the module-level scipy import.
    lo, hi = σ2_fp_bracket
    σ2_fp = brentq(lambda x: Find_Var_log_z(M, π, ε, s_pp, λ_θ, p_emp, x) - m4, lo, hi)

    σ2_β = m2*(m1 - 1)*(σ2_fp + 1/λ_θ**2)
    σ2_pp = (1 - m2)/m2 * σ2_β + σ2_fp
    β = (σ2_β/(π*(1 - π)))**0.5
    ε_pp = σ2_β*(np.exp(β) - 1)/β
    # Chosen so that dh(π, ε_pp, h̄_pp)/(1 - p_emp) equals β.
    h̄_pp = β*(1 - p_emp)/π**(1/ε_pp)
    l_fp = (1 - p_emp)**(ε/(1 + ε))
    μ_dif = np.log(m3 * l_fp/π) - (1 - m2)/m2 * σ2_β/2
    denom = ( λ_θ/(λ_θ - 1)*(s_pp*np.exp(μ_dif + σ2_pp/2)*π
             + (1 - s_pp)*l_fp*np.exp(σ2_fp/2)) )
    μ_fp = np.log(Y_emp/denom)
    μ_pp = μ_fp + μ_dif

    return G_emp, ε, 1, ε_pp, h̄_pp, λ_θ, s_pp, σ2_pp, μ_pp, σ2_fp, μ_fp